{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "opencv_image_search.ipynb",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python2",
      "display_name": "Python 2"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "Uav4QVT33-G6",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "import cv2\n",
        "import numpy as np\n",
        "import scipy\n",
        "from scipy.misc import imread\n",
        "import cPickle as pickle\n",
        "import random\n",
        "import os\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "foOzsuUP4QYH",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# Feature extractor\n",
        "def extract_features(image_path, vector_size=32):\n",
        "    image = imread(image_path, mode=\"RGB\")\n",
        "    try:\n",
        "        # Using KAZE, cause SIFT, ORB and other was moved to additional module\n",
        "        # which is adding addtional pain during install\n",
        "        alg = cv2.KAZE_create()\n",
        "        # Dinding image keypoints\n",
        "        kps = alg.detect(image)\n",
        "        # Getting first 32 of them. \n",
        "        # Number of keypoints is varies depend on image size and color pallet\n",
        "        # Sorting them based on keypoint response value(bigger is better)\n",
        "        kps = sorted(kps, key=lambda x: -x.response)[:vector_size]\n",
        "        # computing descriptors vector\n",
        "        kps, dsc = alg.compute(image, kps)\n",
        "        # Flatten all of them in one big vector - our feature vector\n",
        "        dsc = dsc.flatten()\n",
        "        # Making descriptor of same size\n",
        "        # Descriptor vector size is 64\n",
        "        needed_size = (vector_size * 64)\n",
        "        if dsc.size < needed_size:\n",
        "            # if we have less the 32 descriptors then just adding zeros at the\n",
        "            # end of our feature vector\n",
        "            dsc = np.concatenate([dsc, np.zeros(needed_size - dsc.size)])\n",
        "    except cv2.error as e:\n",
        "        print 'Error: ', e\n",
        "        return None\n",
        "\n",
        "    return dsc\n",
        "\n",
        "\n",
        "def batch_extractor(images_path, pickled_db_path=\"features.pck\"):\n",
        "    files = [os.path.join(images_path, p) for p in sorted(os.listdir(images_path))]\n",
        "\n",
        "    result = {}\n",
        "    for f in files:\n",
        "        print 'Extracting features from image %s' % f\n",
        "        name = f.split('/')[-1].lower()\n",
        "        result[name] = extract_features(f)\n",
        "    \n",
        "    # saving all our feature vectors in pickled file\n",
        "    with open(pickled_db_path, 'w') as fp:\n",
        "        pickle.dump(result, fp)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Q20sohfC5Vng",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class Matcher(object):\n",
        "\n",
        "    def __init__(self, pickled_db_path=\"features.pck\"):\n",
        "        with open(pickled_db_path) as fp:\n",
        "            self.data = pickle.load(fp)\n",
        "        self.names = []\n",
        "        self.matrix = []\n",
        "        for k, v in self.data.iteritems():\n",
        "            self.names.append(k)\n",
        "            self.matrix.append(v)\n",
        "        self.matrix = np.array(self.matrix)\n",
        "        self.names = np.array(self.names)\n",
        "\n",
        "    def cos_cdist(self, vector):\n",
        "        # getting cosine distance between search image and images database\n",
        "        v = vector.reshape(1, -1)\n",
        "        return scipy.spatial.distance.cdist(self.matrix, v, 'cosine').reshape(-1)\n",
        "\n",
        "    def match(self, image_path, topn=5):\n",
        "        features = extract_features(image_path)\n",
        "        img_distances = self.cos_cdist(features)\n",
        "        # getting top 5 records\n",
        "        nearest_ids = np.argsort(img_distances)[:topn].tolist()\n",
        "        nearest_img_paths = self.names[nearest_ids].tolist()\n",
        "\n",
        "        return nearest_img_paths, img_distances[nearest_ids].tolist()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "RnQwLPMi7vXp",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def show_img(path):\n",
        "    img = imread(path, mode=\"RGB\")\n",
        "    plt.imshow(img)\n",
        "    plt.show()\n",
        "    \n",
        "def run():\n",
        "    images_path = '../resources/images/'\n",
        "    files = [os.path.join(images_path, p) for p in sorted(os.listdir(images_path))]\n",
        "    # getting 3 random images \n",
        "    sample = random.sample(files, 3)\n",
        "    \n",
        "    batch_extractor(images_path)\n",
        "\n",
        "    ma = Matcher('features.pck')\n",
        "    \n",
        "    for s in sample:\n",
        "        print 'Query image =========================================='\n",
        "        show_img(s)\n",
        "        names, match = ma.match(s, topn=3)\n",
        "        print 'Result images ========================================'\n",
        "        for i in range(3):\n",
        "            # we got cosine distance, less cosine distance between vectors\n",
        "            # more they similar, thus we subtruct it from 1 to get match value\n",
        "            print 'Match %s' % (1-match[i])\n",
        "            show_img(os.path.join(images_path, names[i]))\n",
        "\n",
        "run()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}